#!/usr/bin/env python
# encoding: utf-8
'''
@author: Excelsiorly
@license: (C) Copyright 2021, All Rights Reserved.
@contact: excelsiorly@qq.com
@file: 004. 寻找两个正序数组的中位数.py
@time: 2021/12/15 20:01
@desc: https://leetcode-cn.com/problems/median-of-two-sorted-arrays/submissions/
> 给定两个大小分别为 m 和 n 的正序（从小到大）数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。
@解题思路:
    1. 归并排序，t: O(m+n), s: O(m+n)
    2. 二分, 每次取到各个数组的第k//2-1个，比较后排除较小的那一部分，然后找第k小的数字转移成找第[k-被排除个数]小的数
        t:O(log(m+n)), s: O(1)
'''

# 合并排序O(m+n)
class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        nums = []
        # 归并排序
        i, j = 0, 0
        while i < len(nums1) and j < len(nums2):
            if nums1[i] < nums2[j]:
                nums.append(nums1[i])
                i += 1
            else:
                nums.append(nums2[j])
                j += 1

        if i < len(nums1):
            for t in range(i, len(nums1)):
                nums.append(nums1[t])

        if j < len(nums2):
            for t in range(j, len(nums2)):
                nums.append(nums2[t])

        n = len(nums)
        if n % 2 == 0:
            return (nums[n // 2 - 1] + nums[n // 2]) / 2
        else:
            return nums[n // 2]

# 二分
class Solution02(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """

        def getKthNum(k):
            m, n = len(nums1), len(nums2)
            idx1, idx2 = 0, 0
            while True:
                # 特殊情况
                # 1. nums1已经都被排除了, 只需返回num2中的现在的第k个
                if idx1 == m:
                    return nums2[idx2 + k - 1]
                # 2. nums2已经都被排除了, 只需返回num1中的现在的第k个
                if idx2 == n:
                    return nums1[idx1 + k - 1]
                # 3. 问题转换到找到剩余元素的最小值
                if k == 1:
                    return min(nums1[idx1], nums2[idx2])
                # 普通情况
                newIdx1 = min(idx1 + k // 2 - 1, m - 1)
                newIdx2 = min(idx2 + k // 2 - 1, n - 1)
                piv1, piv2 = nums1[newIdx1], nums2[newIdx2]
                # 若piv1<piv2，就可以排除num1中newIdx1前的元素，问题转化为找第[k-排除掉的元素个数]个元素
                if piv1 <= piv2:
                    k -= newIdx1 - idx1 + 1
                    idx1 = newIdx1 + 1
                else:
                    k -= newIdx2 - idx2 + 1
                    idx2 = newIdx2 + 1

        m, n = len(nums1), len(nums2)
        totalLength = m + n
        if totalLength % 2 == 1:
            return getKthNum((totalLength + 1) // 2)
        else:
            return (getKthNum(totalLength // 2) + getKthNum(totalLength // 2 + 1)) / 2

if __name__ == '__main__':
    nums1 = [3,4]
    nums2 = [1,2]
    mid = Solution02().findMedianSortedArrays(nums1, nums2)
    print(mid)